{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 293,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"A very simple MNIST classifier.\n",
    "See extensive documentation at\n",
    "https://www.tensorflow.org/get_started/mnist/beginners\n",
    "\"\"\"\n",
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import sys\n",
    "\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "FLAGS = None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Load the MNIST data sets from data_dir; labels come back one-hot encoded.\n",
    "data_dir = '/tmp/tensorflow/mnist/input_data'\n",
    "mnist = input_data.read_data_sets(train_dir=data_dir, one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "metadata": {},
   "outputs": [],
   "source": [
    "REGULARIZATION_RATE = 0.001  # coefficient of the regularization term in the loss function\n",
    "\n",
    "# x receives the input images (784 values per row); y_ receives the one-hot labels (10 classes).\n",
    "# No placeholder is created for `y`: the network-definition cell assigns the model output\n",
    "# to `y`, so a placeholder bound to that name would only be shadowed and never fed.\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y_ = tf.placeholder(tf.float32, [None, 10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _dense_params(n_in, n_out, weight_name, bias_name):\n",
    "    \"\"\"Create the weight matrix and bias vector of one fully connected layer.\"\"\"\n",
    "    weights = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1), name=weight_name)\n",
    "    biases = tf.Variable(tf.zeros([n_out]) + 0.1, name=bias_name)\n",
    "    return weights, biases\n",
    "\n",
    "\n",
    "# Three tanh hidden layers (784 -> 600 -> 500 -> 400) followed by a 10-way softmax output.\n",
    "W1, b1 = _dense_params(784, 600, 'W1', 'b1')\n",
    "L1 = tf.nn.tanh(tf.matmul(x, W1) + b1, name='L1')\n",
    "\n",
    "W2, b2 = _dense_params(600, 500, 'W2', 'b2')\n",
    "L2 = tf.nn.tanh(tf.matmul(L1, W2) + b2, name='L2')\n",
    "\n",
    "W3, b3 = _dense_params(500, 400, 'W3', 'b3')\n",
    "L3 = tf.nn.tanh(tf.matmul(L2, W3) + b3, name='L3')\n",
    "\n",
    "W4, b4 = _dense_params(400, 10, 'W4', 'b4')\n",
    "\n",
    "# Class probabilities produced by the output layer.\n",
    "y = tf.nn.softmax(tf.matmul(L3, W4) + b4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 297,
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 penalty accumulated over the four weight matrices (the biases are not included).\n",
    "regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)\n",
    "regularization = regularizer(W1)\n",
    "for layer_weights in (W2, W3, W4):\n",
    "    regularization = regularization + regularizer(layer_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 298,
   "metadata": {},
   "outputs": [],
   "source": [
    "# softmax_cross_entropy_with_logits applies softmax internally, so it must be fed the\n",
    "# raw output-layer scores. `y` has already gone through tf.nn.softmax; passing it as\n",
    "# `logits` would apply softmax twice and distort the loss and its gradients.\n",
    "# The pre-softmax scores are therefore rebuilt here from the last hidden layer.\n",
    "logits = tf.matmul(L3, W4) + b4\n",
    "cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 299,
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE_BASE = 0.8  # initial learning rate\n",
    "LEARNING_RATE_DECAY = 0.99  # learning-rate decay factor\n",
    "\n",
    "# Step counter that drives the decay schedule; excluded from the trainable variables.\n",
    "global_step = tf.Variable(0, trainable=False)\n",
    "\n",
    "# Decay period: the training-set size divided by 100.\n",
    "decay_steps = mnist.train.num_examples / 100\n",
    "learning_rate = tf.train.exponential_decay(\n",
    "    learning_rate=LEARNING_RATE_BASE,\n",
    "    global_step=global_step,\n",
    "    decay_steps=decay_steps,\n",
    "    decay_rate=LEARNING_RATE_DECAY)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total objective: the data loss plus the L2 weight penalty scaled by REGULARIZATION_RATE.\n",
    "# Minimizing cross_entropy alone would leave the penalty term without any effect.\n",
    "loss = cross_entropy + regularization\n",
    "\n",
    "# Passing global_step makes the optimizer increment it on every update. Without it the\n",
    "# counter stays at 0 and the exponential-decay schedule never lowers the learning rate.\n",
    "train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(\n",
    "    loss, global_step=global_step)\n",
    "\n",
    "# Create the session and initialize every variable defined so far.\n",
    "sess = tf.Session()\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch training: 30000 updates with 100 examples per update.\n",
    "num_steps = 30000\n",
    "batch_size = 100\n",
    "for _ in range(num_steps):\n",
    "    images, labels = mnist.train.next_batch(batch_size)\n",
    "    sess.run(train_step, feed_dict={x: images, y_: labels})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 302,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9802\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained model on the test split.\n",
    "predicted_class = tf.argmax(y, 1)\n",
    "true_class = tf.argmax(y_, 1)\n",
    "correct_prediction = tf.equal(predicted_class, true_class)\n",
    "\n",
    "# Accuracy is the fraction of test examples whose predicted class matches the label.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "test_feed = {x: mnist.test.images, y_: mnist.test.labels}\n",
    "print(sess.run(accuracy, feed_dict=test_feed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
